{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2019-06-03 09:15:48--  https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip\n",
      "Resolving storage.googleapis.com (storage.googleapis.com)... 172.217.24.112, 2404:6800:4003:c04::80\n",
      "Connecting to storage.googleapis.com (storage.googleapis.com)|172.217.24.112|:443... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 407727028 (389M) [application/zip]\n",
      "Saving to: ‘uncased_L-12_H-768_A-12.zip’\n",
      "\n",
      "uncased_L-12_H-768_ 100%[===================>] 388.84M  69.0MB/s    in 5.6s    \n",
      "\n",
      "2019-06-03 09:15:55 (69.0 MB/s) - ‘uncased_L-12_H-768_A-12.zip’ saved [407727028/407727028]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "!wget https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Archive:  uncased_L-12_H-768_A-12.zip\n",
      "   creating: uncased_L-12_H-768_A-12/\n",
      "  inflating: uncased_L-12_H-768_A-12/bert_model.ckpt.meta  \n",
      "  inflating: uncased_L-12_H-768_A-12/bert_model.ckpt.data-00000-of-00001  \n",
      "  inflating: uncased_L-12_H-768_A-12/vocab.txt  \n",
      "  inflating: uncased_L-12_H-768_A-12/bert_model.ckpt.index  \n",
      "  inflating: uncased_L-12_H-768_A-12/bert_config.json  \n"
     ]
    }
   ],
   "source": [
    "!unzip uncased_L-12_H-768_A-12.zip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/husein/.local/lib/python3.6/site-packages/sklearn/cross_validation.py:41: DeprecationWarning: This module was deprecated in version 0.18 in favor of the model_selection module into which all the refactored classes and functions are moved. Also note that the interface of the new CV iterators are different from that of this module. This module will be removed in 0.20.\n",
      "  \"This module will be removed in 0.20.\", DeprecationWarning)\n"
     ]
    }
   ],
   "source": [
    "from utils import *\n",
    "import tensorflow as tf\n",
    "from sklearn.cross_validation import train_test_split\n",
    "import time\n",
    "import random\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['negative', 'positive']\n",
      "10662\n",
      "10662\n"
     ]
    }
   ],
   "source": [
    "trainset = sklearn.datasets.load_files(container_path = 'data', encoding = 'UTF-8')\n",
    "trainset.data, trainset.target = separate_dataset(trainset,1.0)\n",
    "print (trainset.target_names)\n",
    "print (len(trainset.data))\n",
    "print (len(trainset.target))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "BERT_VOCAB = 'uncased_L-12_H-768_A-12/vocab.txt'\n",
    "BERT_INIT_CHKPNT = 'uncased_L-12_H-768_A-12/bert_model.ckpt'\n",
    "BERT_CONFIG = 'uncased_L-12_H-768_A-12/bert_config.json'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "import bert\n",
    "from bert import run_classifier\n",
    "from bert import optimization\n",
    "from bert import tokenization\n",
    "from bert import modeling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenization.validate_case_matches_checkpoint(True,BERT_INIT_CHKPNT)\n",
    "tokenizer = tokenization.FullTokenizer(\n",
    "      vocab_file=BERT_VOCAB, do_lower_case=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "MAX_SEQ_LENGTH = 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10662/10662 [00:03<00:00, 3511.65it/s]\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "input_ids, input_masks, segment_ids = [], [], []\n",
    "\n",
    "for text in tqdm(trainset.data):\n",
    "    tokens_a = tokenizer.tokenize(text)\n",
    "    if len(tokens_a) > MAX_SEQ_LENGTH - 2:\n",
    "        tokens_a = tokens_a[:(MAX_SEQ_LENGTH - 2)]\n",
    "    tokens = [\"[CLS]\"] + tokens_a + [\"[SEP]\"]\n",
    "    segment_id = [0] * len(tokens)\n",
    "    input_id = tokenizer.convert_tokens_to_ids(tokens)\n",
    "    input_mask = [1] * len(input_id)\n",
    "    padding = [0] * (MAX_SEQ_LENGTH - len(input_id))\n",
    "    input_id += padding\n",
    "    input_mask += padding\n",
    "    segment_id += padding\n",
    "    \n",
    "    input_ids.append(input_id)\n",
    "    input_masks.append(input_mask)\n",
    "    segment_ids.append(segment_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "bert_config = modeling.BertConfig.from_json_file(BERT_CONFIG)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['production',\n",
       " 'design',\n",
       " 'score',\n",
       " 'choreography',\n",
       " 'simply',\n",
       " 'into',\n",
       " '##xi',\n",
       " '##cating']"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.tokenize(trainset.data[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "epoch = 5\n",
    "batch_size = 3\n",
    "warmup_proportion = 0.1\n",
    "num_train_steps = int(len(input_ids) / batch_size * epoch)\n",
    "num_warmup_steps = int(num_train_steps * warmup_proportion)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Model:\n",
    "    def __init__(\n",
    "        self,\n",
    "        dimension_output,\n",
    "        learning_rate = 2e-5,\n",
    "    ):\n",
    "        self.X = tf.placeholder(tf.int32, [None, None])\n",
    "        self.segment_ids = tf.placeholder(tf.int32, [None, None])\n",
    "        self.input_masks = tf.placeholder(tf.int32, [None, None])\n",
    "        self.Y = tf.placeholder(tf.int32, [None])\n",
    "        \n",
    "        model = modeling.BertModel(\n",
    "            config=bert_config,\n",
    "            is_training=True,\n",
    "            input_ids=self.X,\n",
    "            input_mask=self.input_masks,\n",
    "            token_type_ids=self.segment_ids,\n",
    "            use_one_hot_embeddings=False)\n",
    "        \n",
    "        output_layer = model.get_pooled_output()\n",
    "        self.logits = tf.layers.dense(output_layer, dimension_output)\n",
    "        \n",
    "        self.cost = tf.reduce_mean(\n",
    "            tf.nn.sparse_softmax_cross_entropy_with_logits(\n",
    "                logits = self.logits, labels = self.Y\n",
    "            )\n",
    "        )\n",
    "        \n",
    "        self.optimizer = optimization.create_optimizer(self.cost, learning_rate, \n",
    "                                                       num_train_steps, num_warmup_steps, False)\n",
    "        correct_pred = tf.equal(\n",
    "            tf.argmax(self.logits, 1, output_type = tf.int32), self.Y\n",
    "        )\n",
    "        self.accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from uncased_L-12_H-768_A-12/bert_model.ckpt\n"
     ]
    }
   ],
   "source": [
    "dimension_output = 2\n",
    "learning_rate = 2e-5\n",
    "\n",
    "tf.reset_default_graph()\n",
    "sess = tf.InteractiveSession()\n",
    "model = Model(\n",
    "    dimension_output,\n",
    "    learning_rate\n",
    ")\n",
    "\n",
    "sess.run(tf.global_variables_initializer())\n",
    "var_lists = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope = 'bert')\n",
    "saver = tf.train.Saver(var_list = var_lists)\n",
    "saver.restore(sess, BERT_INIT_CHKPNT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_input_ids, test_input_ids, train_input_masks, test_input_masks, train_segment_ids, test_segment_ids, train_Y, test_Y = train_test_split(\n",
    "    input_ids, input_masks, segment_ids, trainset.target, test_size = 0.2\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 2843/2843 [14:02<00:00,  3.37it/s, accuracy=0.667, cost=1.69] \n",
      "test minibatch loop: 100%|██████████| 711/711 [00:44<00:00, 16.31it/s, accuracy=0.333, cost=0.995]\n",
      "train minibatch loop:   0%|          | 0/2843 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0, pass acc: 0.000000, current acc: 0.728551\n",
      "time taken: 886.7454442977905\n",
      "epoch: 0, training loss: 0.689427, training acc: 0.724469, valid loss: 0.932971, valid acc: 0.728551\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 2843/2843 [14:01<00:00,  3.38it/s, accuracy=0.667, cost=2.02] \n",
      "test minibatch loop: 100%|██████████| 711/711 [00:43<00:00, 16.22it/s, accuracy=0.667, cost=2.23] \n",
      "train minibatch loop:   0%|          | 0/2843 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 1, pass acc: 0.728551, current acc: 0.794187\n",
      "time taken: 885.3626654148102\n",
      "epoch: 1, training loss: 0.524612, training acc: 0.871263, valid loss: 0.931003, valid acc: 0.794187\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 2843/2843 [14:03<00:00,  3.37it/s, accuracy=0.667, cost=0.367]\n",
      "test minibatch loop: 100%|██████████| 711/711 [00:43<00:00, 16.20it/s, accuracy=0.333, cost=2.61] \n",
      "train minibatch loop:   0%|          | 0/2843 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 2, pass acc: 0.794187, current acc: 0.798406\n",
      "time taken: 886.9451875686646\n",
      "epoch: 2, training loss: 0.276080, training acc: 0.944894, valid loss: 1.174319, valid acc: 0.798406\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 2843/2843 [14:00<00:00,  3.39it/s, accuracy=1, cost=7.8e-5]   \n",
      "test minibatch loop: 100%|██████████| 711/711 [00:43<00:00, 16.23it/s, accuracy=0.667, cost=2.19] \n",
      "train minibatch loop:   0%|          | 0/2843 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 3, pass acc: 0.798406, current acc: 0.798875\n",
      "time taken: 884.7811346054077\n",
      "epoch: 3, training loss: 0.126057, training acc: 0.977840, valid loss: 1.651340, valid acc: 0.798875\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 2843/2843 [14:03<00:00,  3.37it/s, accuracy=1, cost=2.63e-5]  \n",
      "test minibatch loop: 100%|██████████| 711/711 [00:43<00:00, 16.16it/s, accuracy=0.667, cost=1.91] \n",
      "train minibatch loop:   0%|          | 0/2843 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 4, pass acc: 0.798875, current acc: 0.802625\n",
      "time taken: 887.2457585334778\n",
      "epoch: 4, training loss: 0.064315, training acc: 0.989682, valid loss: 1.709754, valid acc: 0.802625\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 2843/2843 [14:03<00:00,  3.37it/s, accuracy=1, cost=1.78e-5]  \n",
      "test minibatch loop: 100%|██████████| 711/711 [00:44<00:00, 16.10it/s, accuracy=0.667, cost=2.75] \n",
      "train minibatch loop:   0%|          | 0/2843 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 5, pass acc: 0.802625, current acc: 0.811064\n",
      "time taken: 887.5904607772827\n",
      "epoch: 5, training loss: 0.032120, training acc: 0.994607, valid loss: 1.797124, valid acc: 0.811064\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 2843/2843 [14:09<00:00,  3.37it/s, accuracy=1, cost=2.14e-5]  \n",
      "test minibatch loop: 100%|██████████| 711/711 [00:44<00:00, 16.06it/s, accuracy=0.333, cost=3.31] \n",
      "train minibatch loop:   0%|          | 0/2843 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time taken: 893.3844969272614\n",
      "epoch: 6, training loss: 0.020082, training acc: 0.996834, valid loss: 1.802879, valid acc: 0.804970\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 2843/2843 [14:08<00:00,  3.37it/s, accuracy=1, cost=2.26e-5]  \n",
      "test minibatch loop: 100%|██████████| 711/711 [00:43<00:00, 16.20it/s, accuracy=0.333, cost=5.04] \n",
      "train minibatch loop:   0%|          | 0/2843 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time taken: 892.2061610221863\n",
      "epoch: 7, training loss: 0.018238, training acc: 0.997303, valid loss: 1.836484, valid acc: 0.802157\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 2843/2843 [14:02<00:00,  3.37it/s, accuracy=1, cost=2.18e-5]  \n",
      "test minibatch loop: 100%|██████████| 711/711 [00:43<00:00, 16.25it/s, accuracy=0.667, cost=2.97] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time taken: 886.4455940723419\n",
      "epoch: 8, training loss: 0.018759, training acc: 0.996834, valid loss: 1.801667, valid acc: 0.809189\n",
      "\n",
      "break epoch:9\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "import time\n",
    "\n",
    "EARLY_STOPPING, CURRENT_CHECKPOINT, CURRENT_ACC, EPOCH = 3, 0, 0, 0\n",
    "\n",
    "while True:\n",
    "    lasttime = time.time()\n",
    "    if CURRENT_CHECKPOINT == EARLY_STOPPING:\n",
    "        print('break epoch:%d\\n' % (EPOCH))\n",
    "        break\n",
    "\n",
    "    train_acc, train_loss, test_acc, test_loss = 0, 0, 0, 0\n",
    "    pbar = tqdm(\n",
    "        range(0, len(train_input_ids), batch_size), desc = 'train minibatch loop'\n",
    "    )\n",
    "    for i in pbar:\n",
    "        index = min(i + batch_size, len(train_input_ids))\n",
    "        batch_x = train_input_ids[i: index]\n",
    "        batch_masks = train_input_masks[i: index]\n",
    "        batch_segment = train_segment_ids[i: index]\n",
    "        batch_y = train_Y[i: index]\n",
    "        acc, cost, _ = sess.run(\n",
    "            [model.accuracy, model.cost, model.optimizer],\n",
    "            feed_dict = {\n",
    "                model.Y: batch_y,\n",
    "                model.X: batch_x,\n",
    "                model.segment_ids: batch_segment,\n",
    "                model.input_masks: batch_masks\n",
    "            },\n",
    "        )\n",
    "        assert not np.isnan(cost)\n",
    "        train_loss += cost\n",
    "        train_acc += acc\n",
    "        pbar.set_postfix(cost = cost, accuracy = acc)\n",
    "\n",
    "    pbar = tqdm(range(0, len(test_input_ids), batch_size), desc = 'test minibatch loop')\n",
    "    for i in pbar:\n",
    "        index = min(i + batch_size, len(test_input_ids))\n",
    "        batch_x = test_input_ids[i: index]\n",
    "        batch_masks = test_input_masks[i: index]\n",
    "        batch_segment = test_segment_ids[i: index]\n",
    "        batch_y = test_Y[i: index]\n",
    "        acc, cost = sess.run(\n",
    "            [model.accuracy, model.cost],\n",
    "            feed_dict = {\n",
    "                model.Y: batch_y,\n",
    "                model.X: batch_x,\n",
    "                model.segment_ids: batch_segment,\n",
    "                model.input_masks: batch_masks\n",
    "            },\n",
    "        )\n",
    "        test_loss += cost\n",
    "        test_acc += acc\n",
    "        pbar.set_postfix(cost = cost, accuracy = acc)\n",
    "\n",
    "    train_loss /= len(train_input_ids) / batch_size\n",
    "    train_acc /= len(train_input_ids) / batch_size\n",
    "    test_loss /= len(test_input_ids) / batch_size\n",
    "    test_acc /= len(test_input_ids) / batch_size\n",
    "\n",
    "    if test_acc > CURRENT_ACC:\n",
    "        print(\n",
    "            'epoch: %d, pass acc: %f, current acc: %f'\n",
    "            % (EPOCH, CURRENT_ACC, test_acc)\n",
    "        )\n",
    "        CURRENT_ACC = test_acc\n",
    "        CURRENT_CHECKPOINT = 0\n",
    "    else:\n",
    "        CURRENT_CHECKPOINT += 1\n",
    "        \n",
    "    print('time taken:', time.time() - lasttime)\n",
    "    print(\n",
    "        'epoch: %d, training loss: %f, training acc: %f, valid loss: %f, valid acc: %f\\n'\n",
    "        % (EPOCH, train_loss, train_acc, test_loss, test_acc)\n",
    "    )\n",
    "    EPOCH += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
